Add and verify support for deterministic fp8 dpa/mha on SM100 #2621

sudhakarsingh27 wants to merge 3 commits into NVIDIA:main

Conversation
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
for more information, see https://pre-commit.ci

/te-ci pytorch L1
Greptile Summary

Extends deterministic FP8 attention support from FP16/BF16 (PR #2584) to FP8 data types on the SM100 (Blackwell) architecture. The implementation threads the `deterministic` flag from the Python-level backend selection down through the C++ fused-attention entry points into the cuDNN FP8 backward path.

Key changes:
- Thread the `deterministic` argument through `fused_attn_fp8.cu`.
- Update `pytorch/attention/dot_product_attention/utils.py` to allow FP8 + deterministic kernels only on SM100+ (arch >= 10.0).
- Extend `test_attention.py` to exercise FP8 with `deterministic=True`.

The version check in `fused_attn_fp8.cu` gates the deterministic algorithm on cuDNN >= 91900 (9.19.0), which appears to conflict with the 9.18.1+ requirement cited in PR #2584; see the review comment below.

Confidence Score: 4/5
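The gating is easiest to see as code. Below is a minimal sketch of the SM100 filter described above; the function and parameter names are illustrative stand-ins, not the exact identifiers used in `get_attention_backend`:

```python
# Hedged sketch of the SM100+ filter; names here are hypothetical, the real
# logic lives in pytorch/attention/dot_product_attention/utils.py.
def allow_fp8_deterministic(
    fp8: bool, deterministic: bool, compute_capability: tuple[int, int]
) -> bool:
    """Permit FP8 + deterministic fused attention only on SM100+ (arch >= 10.0)."""
    if fp8 and deterministic:
        # Blackwell SM100 reports compute capability (10, 0).
        return compute_capability >= (10, 0)
    return True  # non-FP8 or non-deterministic paths are unaffected by this filter
```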
Sequence Diagram

```mermaid
sequenceDiagram
participant Test as test_mha_fp8_vs_f16/<br/>test_dpa_fp8_vs_f16
participant Utils as utils.py:<br/>get_attention_backend
participant CPP as fused_attn.cpp:<br/>nvte_fused_attn_bwd*
participant FP8 as fused_attn_fp8.cu:<br/>fused_attn_fp8_bwd
participant Impl as fused_attn_fp8.cu:<br/>fused_attn_fp8_bwd_impl_v1
participant cuDNN as cuDNN Backend
Note over Test: deterministic param<br/>added to tests
Test->>Utils: check backend with<br/>deterministic flag
Note over Utils: Filter: Allow FP8+deterministic<br/>only on SM100+ (arch >= 10.0)
Utils-->>Test: backend available
Test->>CPP: nvte_fused_attn_bwd_*<br/>(deterministic)
Note over CPP: QKV packed/<br/>KV packed/<br/>separate paths
CPP->>FP8: fused_attn_fp8_bwd<br/>(deterministic)
FP8->>Impl: fused_attn_fp8_bwd_impl_v1<br/>(deterministic)
Note over Impl: Check cuDNN version
alt cudnn_runtime_version >= 91900
Impl->>cuDNN: set_deterministic_algorithm<br/>(deterministic)
else version < 91900
Note over Impl: deterministic flag ignored
end
Impl->>cuDNN: execute backward pass
cuDNN-->>Impl: gradients
Impl-->>FP8: return
FP8-->>CPP: return
CPP-->>Test: return
```
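To make the diagram concrete, here is a minimal sketch of how a caller reaches this path from Python. It assumes TE's `NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment toggle and the `fp8_dpa` recipe flag behave as in current TransformerEngine; the shapes and module arguments are illustrative, and the real entry points are the `test_dpa_fp8_vs_f16`/`test_mha_fp8_vs_f16` tests:

```python
import os
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"  # request deterministic kernels

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

dpa = te.DotProductAttention(num_attention_heads=16, kv_channels=64)
# Default qkv_format is "sbhd": (seq, batch, heads, head_dim).
q, k, v = (
    torch.randn(128, 2, 16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    for _ in range(3)
)

with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(fp8_dpa=True)):
    out = dpa(q, k, v)  # dispatched to the FP8 fused-attention backend on SM100
out.sum().backward()  # backward threads `deterministic` down into fused_attn_fp8.cu
```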
Review comment on the version gate in `fused_attn_fp8.cu`:

```cpp
if (cudnn_runtime_version >= 91900) {
  sdpa_backward_options.set_deterministic_algorithm(deterministic);
}
```
logic: The version check uses 91900 (cuDNN 9.19.0), but the related PR #2584 and the description mention a 9.18.1+ requirement. Should this be 91810 instead?
Suggested change:

```diff
-if (cudnn_runtime_version >= 91900) {
+if (cudnn_runtime_version >= 91810) {
   sdpa_backward_options.set_deterministic_algorithm(deterministic);
 }
```
Is there a specific reason FP8 requires cuDNN 9.19.0+ while FP16/BF16 only needs 9.18.1+?
Description
Follow-up to #2584: add and verify support for "deterministic" fp8 dpa/mha cuDNN attention kernels.
Changes
Please list the changes introduced in this PR:
- Thread the `deterministic` argument through `fused_attn_fp8.cu`.
- Update `pytorch/attention/dot_product_attention/utils.py` to allow fp8 + deterministic kernels on SM100.
- Update `test_attention.py` to check fp8 with `deterministic=True` (a sketch of the determinism check follows this list).
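Per the last item, a sketch of what the determinism check amounts to: two identical FP8 forward+backward passes must produce bitwise-identical gradients. The helper below is hypothetical scaffolding; the actual tests are `test_dpa_fp8_vs_f16` and `test_mha_fp8_vs_f16` in `test_attention.py`:

```python
import torch

def grads_are_deterministic(run_fwd_bwd) -> bool:
    """run_fwd_bwd() performs one forward+backward pass and returns the gradients."""
    torch.manual_seed(0)
    grads_a = run_fwd_bwd()
    torch.manual_seed(0)
    grads_b = run_fwd_bwd()
    # Deterministic kernels must reproduce gradients exactly,
    # not merely within a numerical tolerance.
    return all(torch.equal(a, b) for a, b in zip(grads_a, grads_b))
```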